-
Notifications
You must be signed in to change notification settings - Fork 35
Use Threads.nthreads() * 2
in TSVI
#936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Benchmark Report for Commit 460a65eComputer Information
Benchmark Results
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #936 +/- ##
=======================================
Coverage 82.92% 82.92%
=======================================
Files 36 36
Lines 3964 3964
=======================================
Hits 3287 3287
Misses 677 677 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Pull Request Test Coverage Report for Build 15389029544Warning: This coverage report may be inaccurate.This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
Pull Request Test Coverage Report for Build 15216502921Details
💛 - Coveralls |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy with this as a hacky workaround, would only propose documenting why this is the way it is.
How would TuringLang/Turing.jl#2555 affect TSVI?
@@ -9,7 +9,9 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo | |||
logps::L | |||
end | |||
function ThreadSafeVarInfo(vi::AbstractVarInfo) | |||
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) | |||
return ThreadSafeVarInfo( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a comment here explaining the situation and maybe linking to this PR or the relevant issue? The * 2
would otherwise appear quite mysterious.
By making it opt-in, it would allow us to use In contrast, now TSVI is mandatory whenever |
Oh right, because you can specify an AbstractVarInfo when you make a LogDensityFunction. Got it. |
Optimistically, I might be able to get that done by the end of this week, but if I find it's too hard I'll come back to this! |
6b9b332
to
460a65e
Compare
I don't think the Turing PR will be done any time soon, so I figure we should just merge this. I added a comment as requested. |
DynamicPPL.jl documentation for PR #936 is available at: |
See #924 for background.
This PR adopts a similar approach to solution (1), i.e., using
Threads.maxthreadid()
. However, Mooncake can't differentiatemaxthreadid()
, so this is a hacky workaround to a hacky workaround, based on the observation thatmaxthreadid()
seems to be upper-bounded bynthreads() * 2
.Personally, I would be much more in favour of removing TSVI (or making it opt-in), and will probably do this in a separate PR once TuringLang/Turing.jl#2555 is solved.
But this is a quick enough fix that should ensure that TSVI continues to work on Julia 1.12 (even if for the wrong reasons). Whether we merge this will probably depend on how fast that Turing issue can be fixed -- if it's soon, then we can remove TSVI, if it's not soon, then this can plug the gap.
Note that if TSVI is opt-in, then we can just use
maxthreadid()
, thereby removing one layer of hackiness — because Mooncake doesn't work with multithreaded execution anyway, the scenarios where you'd want to opt into TSVI have no overlap with the scenarios where you'd want to use Mooncake.